﻿using System.Collections;
using System.Windows;
using System.Windows.Controls;
using System.Windows.Input;
using System.Windows.Media;
using System.Windows.Media.Media3D;
using KeyEventArgs = System.Windows.Input.KeyEventArgs;
using MouseEventArgs = System.Windows.Input.MouseEventArgs;
using TreeView = System.Windows.Controls.TreeView;

namespace Su.WPF.CustomControl.TreeViewEx
{
    internal class MultiSelectTreeView : TreeView
    {
        /// <summary>
        /// Creates the tree view and attaches the focus, left-mouse-button and keyboard
        /// handlers that drive the multi-selection logic of this class.
        /// </summary>
        public MultiSelectTreeView()
        {
            AddHandler(GotFocusEvent, new RoutedEventHandler(OnTreeViewItemGotFocus));
            AddHandler(
                PreviewMouseLeftButtonDownEvent,
                new MouseButtonEventHandler(OnTreeViewItemPreviewMouseDown)
            );
            AddHandler(
                PreviewMouseLeftButtonUpEvent,
                new MouseButtonEventHandler(OnTreeViewItemPreviewMouseUp)
            );

            // Keyboard handling (Ctrl+A).
            AddHandler(KeyDownEvent, new System.Windows.Input.KeyEventHandler(OnTreeViewKeyDown));
        }

        // Item whose selection handling was postponed by OnTreeViewItemGotFocus until the
        // left button is released; OnTreeViewItemPreviewMouseUp compares against it.
        // The field is static, so a single slot is shared by every instance of this class.
        private static TreeViewItem _selectTreeViewItemOnMouseUp;

        /// <summary>
        /// Attached property "IsItemSelected": a boolean flag (default <c>false</c>) marking
        /// an item as part of the multi-selection. Every change is routed to
        /// <see cref="OnIsItemSelectedPropertyChanged"/>.
        /// </summary>
        public static readonly DependencyProperty IsItemSelectedProperty =
            DependencyProperty.RegisterAttached(
                "IsItemSelected",
                typeof(bool),
                typeof(MultiSelectTreeView),
                new PropertyMetadata(false, OnIsItemSelectedPropertyChanged)
            );

        /// <summary>
        /// Reads the <see cref="IsItemSelectedProperty"/> flag of an item.
        /// </summary>
        /// <param name="element">Item to query; may be <c>null</c>.</param>
        /// <returns>
        /// The stored flag, or <c>false</c> when <paramref name="element"/> is <c>null</c>
        /// (mirrors the null tolerance of <see cref="SetIsItemSelected"/>).
        /// </returns>
        public static bool GetIsItemSelected(TreeViewItem element)
        {
            // A missing item can never be selected; previously this dereferenced null.
            if (element == null)
                return false;

            return (bool)element.GetValue(IsItemSelectedProperty);
        }

        /// <summary>
        /// Writes the <see cref="IsItemSelectedProperty"/> flag of an item.
        /// </summary>
        /// <param name="element">Item to update; a <c>null</c> item is silently ignored.</param>
        /// <param name="value">New selection state.</param>
        public static void SetIsItemSelected(TreeViewItem element, bool value)
        {
            element?.SetValue(IsItemSelectedProperty, value);
        }

        /// <summary>
        /// Change callback of <see cref="IsItemSelectedProperty"/>. Keeps the list returned by
        /// <see cref="GetSelectedItems"/> for the owning tree view in step with the flag by
        /// adding or removing the item's <c>Header</c>.
        /// </summary>
        /// <param name="d">Object whose flag changed; anything but a TreeViewItem is ignored.</param>
        /// <param name="e">Change details (not inspected; the current value is re-read).</param>
        private static void OnIsItemSelectedPropertyChanged(
            DependencyObject d,
            DependencyPropertyChangedEventArgs e
        )
        {
            if (d is not TreeViewItem changedItem)
                return;

            // The item must sit under a tree view that has a selection list attached.
            var owner = FindTreeView(changedItem);
            if (owner == null)
                return;

            var selection = GetSelectedItems(owner);
            if (selection == null)
                return;

            object header = changedItem.Header;
            if (GetIsItemSelected(changedItem))
                selection.Add(header);
            else
                selection.Remove(header);
        }

        /// <summary>
        /// Attached property "SelectedItems" of type <see cref="IList"/>. The list stored here
        /// is the one OnIsItemSelectedPropertyChanged adds item headers to and removes them
        /// from. No metadata is supplied at registration.
        /// </summary>
        public static readonly DependencyProperty SelectedItemsProperty =
            DependencyProperty.RegisterAttached(
                "SelectedItems",
                typeof(IList),
                typeof(MultiSelectTreeView)
            );

        /// <summary>
        /// Returns the list stored in <see cref="SelectedItemsProperty"/> on the given tree view.
        /// </summary>
        /// <param name="element">Tree view to read from.</param>
        /// <returns>The stored list, or <c>null</c> when the stored value is <c>null</c>.</returns>
        public static IList GetSelectedItems(TreeView element) =>
            element.GetValue(SelectedItemsProperty) as IList;

        /// <summary>
        /// Stores a list in <see cref="SelectedItemsProperty"/> on the given tree view.
        /// </summary>
        /// <param name="element">Tree view to write to.</param>
        /// <param name="value">List that will receive the headers of selected items.</param>
        public static void SetSelectedItems(TreeView element, IList value) =>
            element.SetValue(SelectedItemsProperty, value);

        /// <summary>
        /// Private attached property "StartItem": the item stored per tree view as the start
        /// (anchor) of a range selection; read by SelectMultipleItemsContinuously.
        /// </summary>
        private static readonly DependencyProperty StartItemProperty =
            DependencyProperty.RegisterAttached(
                "StartItem",
                typeof(TreeViewItem),
                typeof(MultiSelectTreeView)
            );

        /// <summary>
        /// Returns the range-selection start item stored on the given tree view, or
        /// <c>null</c> when none is stored.
        /// </summary>
        private static TreeViewItem GetStartItem(TreeView element) =>
            element.GetValue(StartItemProperty) as TreeViewItem;

        /// <summary>
        /// Stores the range-selection start item on the given tree view; pass <c>null</c>
        /// to clear it.
        /// </summary>
        private static void SetStartItem(TreeView element, TreeViewItem value) =>
            element.SetValue(StartItemProperty, value);

        /// <summary>
        /// Keyboard handler: Ctrl+A selects every item and marks the event as handled.
        /// Any other key is left untouched.
        /// </summary>
        /// <param name="sender">Expected to be the tree view; other senders are ignored.</param>
        /// <param name="e">Key event data.</param>
        private static void OnTreeViewKeyDown(object sender, KeyEventArgs e)
        {
            if (sender is not TreeView treeView)
                return;

            if (e.Key != Key.A)
                return;

            // Control only needs to be among the held modifiers, not the sole one.
            if ((Keyboard.Modifiers & ModifierKeys.Control) == ModifierKeys.None)
                return;

            SelectAllItems(treeView);
            e.Handled = true;
        }

        /// <summary>
        /// Selects every item container that <see cref="GetAllItems"/> finds in the tree
        /// and makes the first one the range-selection start item.
        /// </summary>
        /// <param name="treeView">Tree view to operate on; <c>null</c> is ignored.</param>
        private static void SelectAllItems(TreeView treeView)
        {
            if (treeView is null)
                return;

            var containers = new List<TreeViewItem>();
            GetAllItems(treeView, null, containers);

            for (int index = 0; index < containers.Count; index++)
            {
                SetIsItemSelected(containers[index], true);
            }

            // Anchor later range selections at the first item.
            if (containers.Count != 0)
            {
                SetStartItem(treeView, containers[0]);
            }
        }

        /// <summary>
        /// Chooses the selection behaviour from the modifier keys currently held:
        /// Ctrl+Shift (possibly with others) extends a range without clearing, exactly Ctrl
        /// toggles one item, exactly Shift selects a range, anything else selects one item.
        /// </summary>
        /// <param name="treeViewItem">Item that was activated; <c>null</c> is ignored.</param>
        /// <param name="treeView">Owning tree view; <c>null</c> is ignored.</param>
        private static void SelectItems(TreeViewItem treeViewItem, TreeView treeView)
        {
            if (treeViewItem == null || treeView == null)
                return;

            const ModifierKeys ctrlShift = ModifierKeys.Control | ModifierKeys.Shift;
            ModifierKeys held = Keyboard.Modifiers;

            // Checked as a mask, so extra modifiers do not prevent the match.
            if ((held & ctrlShift) == ctrlShift)
            {
                SelectMultipleItemsContinuously(treeView, treeViewItem, true);
                return;
            }

            // The remaining cases require an exact modifier state.
            switch (held)
            {
                case ModifierKeys.Control:
                    SelectMultipleItemsRandomly(treeView, treeViewItem);
                    break;
                case ModifierKeys.Shift:
                    SelectMultipleItemsContinuously(treeView, treeViewItem);
                    break;
                default:
                    SelectSingleItem(treeView, treeViewItem);
                    break;
            }
        }

        /// <summary>
        /// Focus handler. Applies the selection gesture to the item that received focus,
        /// unless the left button is pressed on an already selected item without exactly
        /// Ctrl held; in that case the item is remembered in
        /// <c>_selectTreeViewItemOnMouseUp</c> and handled on button release instead.
        /// </summary>
        /// <param name="sender">The tree view the handler is attached to.</param>
        /// <param name="e">Event data; <c>OriginalSource</c> identifies the focused element.</param>
        private static void OnTreeViewItemGotFocus(object sender, RoutedEventArgs e)
        {
            // Any previously postponed selection is obsolete once focus moves.
            _selectTreeViewItemOnMouseUp = null;

            if (e.OriginalSource is TreeView)
                return;

            var treeViewItem = FindTreeViewItem(e.OriginalSource as DependencyObject);

            // The focused element may not belong to any item; there is nothing to select
            // then, and querying its selection state would dereference null.
            if (treeViewItem == null)
                return;

            if (
                Mouse.LeftButton == MouseButtonState.Pressed
                && GetIsItemSelected(treeViewItem)
                && Keyboard.Modifiers != ModifierKeys.Control
            )
            {
                _selectTreeViewItemOnMouseUp = treeViewItem;
                return;
            }

            SelectItems(treeViewItem, sender as TreeView);
        }

        /// <summary>
        /// Left-button-down handler. When the pressed item already has focus, the focus
        /// handler is invoked directly with the same arguments so the click is still
        /// processed.
        /// </summary>
        /// <param name="sender">The tree view the handler is attached to.</param>
        /// <param name="e">Mouse event data; <c>OriginalSource</c> identifies the hit element.</param>
        private static void OnTreeViewItemPreviewMouseDown(object sender, MouseEventArgs e)
        {
            var pressedItem = FindTreeViewItem(e.OriginalSource as DependencyObject);
            if (pressedItem is { IsFocused: true })
            {
                OnTreeViewItemGotFocus(sender, e);
            }
        }

        /// <summary>
        /// Left-button-up handler. Completes a selection that the focus handler postponed:
        /// if the button is released over the remembered item, the selection gesture is
        /// applied to it. The remembered item is always cleared afterwards.
        /// </summary>
        /// <param name="sender">The tree view the handler is attached to.</param>
        /// <param name="e">Mouse event data; <c>OriginalSource</c> identifies the hit element.</param>
        private static void OnTreeViewItemPreviewMouseUp(object sender, MouseButtonEventArgs e)
        {
            // Consume the pending item so a later, unrelated release cannot re-trigger it
            // and the static field does not keep the item alive.
            var pendingItem = _selectTreeViewItemOnMouseUp;
            _selectTreeViewItemOnMouseUp = null;

            if (pendingItem == null)
                return;

            var treeViewItem = FindTreeViewItem(e.OriginalSource as DependencyObject);

            if (treeViewItem == pendingItem)
            {
                SelectItems(treeViewItem, sender as TreeView);
            }
        }

        /// <summary>
        /// Walks up the visual tree from the given element and returns the first
        /// <see cref="TreeViewItem"/> encountered, the element itself included.
        /// </summary>
        /// <param name="dependencyObject">Starting element; may be <c>null</c>.</param>
        /// <returns>
        /// The nearest item, or <c>null</c> when the walk reaches something that is neither
        /// a Visual nor a Visual3D (which includes running out of parents).
        /// </returns>
        private static TreeViewItem FindTreeViewItem(DependencyObject dependencyObject)
        {
            var current = dependencyObject;

            // Only Visual / Visual3D nodes are passed to VisualTreeHelper.GetParent.
            while (current is Visual || current is Visual3D)
            {
                if (current is TreeViewItem found)
                {
                    return found;
                }

                current = VisualTreeHelper.GetParent(current);
            }

            return null;
        }

        /// <summary>
        /// Makes the given item the only selected one and records it as the
        /// range-selection start item.
        /// </summary>
        /// <param name="treeView">Tree view whose selection is reset.</param>
        /// <param name="treeViewItem">Item to select.</param>
        private static void SelectSingleItem(TreeView treeView, TreeViewItem treeViewItem)
        {
            // Clear the whole tree before selecting, starting at the root level.
            DeSelectAllItems(treeView, treeViewItem: null);

            SetIsItemSelected(element: treeViewItem, value: true);
            SetStartItem(element: treeView, value: treeViewItem);
        }

        /// <summary>
        /// Recursively clears the selection flag of every item container below the given
        /// owner. The owner is <paramref name="treeView"/> when it is non-null, otherwise
        /// <paramref name="treeViewItem"/>.
        /// </summary>
        /// <param name="treeView">Root to start from, or <c>null</c> to start from an item.</param>
        /// <param name="treeViewItem">Parent item used when <paramref name="treeView"/> is <c>null</c>.</param>
        private static void DeSelectAllItems(TreeView treeView, TreeViewItem treeViewItem)
        {
            // Both levels are ItemsControls, so a single loop serves either of them.
            ItemsControl owner;
            if (treeView != null)
                owner = treeView;
            else
                owner = treeViewItem;

            for (int index = 0; index < owner.Items.Count; index++)
            {
                // Indices for which no TreeViewItem container is returned are skipped.
                if (owner.ItemContainerGenerator.ContainerFromIndex(index) is TreeViewItem child)
                {
                    SetIsItemSelected(child, false);
                    DeSelectAllItems(null, child);
                }
            }
        }

        /// <summary>
        /// Walks up the visual tree from the given element and returns the first
        /// <see cref="TreeView"/> encountered, the element itself included.
        /// </summary>
        /// <param name="dependencyObject">Starting element; may be <c>null</c>.</param>
        /// <returns>The nearest tree view, or <c>null</c> when the parents run out.</returns>
        private static TreeView FindTreeView(DependencyObject dependencyObject)
        {
            for (
                var node = dependencyObject;
                node != null;
                node = VisualTreeHelper.GetParent(node)
            )
            {
                if (node is TreeView owner)
                {
                    return owner;
                }
            }

            return null;
        }

        /// <summary>
        /// Ctrl+click behaviour: toggles the selection flag of one item and keeps the
        /// range-selection start item up to date.
        /// </summary>
        /// <param name="treeView">Owning tree view; <c>null</c> is ignored.</param>
        /// <param name="treeViewItem">Item to toggle; <c>null</c> is ignored.</param>
        private static void SelectMultipleItemsRandomly(
            TreeView treeView,
            TreeViewItem treeViewItem
        )
        {
            if (treeView == null || treeViewItem == null)
                return;

            SetIsItemSelected(treeViewItem, !GetIsItemSelected(treeViewItem));

            if (GetStartItem(treeView) == null || Keyboard.Modifiers == ModifierKeys.Control)
            {
                // A freshly selected item becomes the new start item.
                if (GetIsItemSelected(treeViewItem))
                {
                    SetStartItem(treeView, treeViewItem);
                }
            }
            else
            {
                // No selection list may be attached to the tree view; only an existing,
                // empty list means the start item should be forgotten.
                var selectedItems = GetSelectedItems(treeView);
                if (selectedItems != null && selectedItems.Count == 0)
                {
                    SetStartItem(treeView, null);
                }
            }
        }

        /// <summary>
        /// Shift+click behaviour: selects every item between the stored start item and the
        /// clicked item (both inclusive), in the order produced by <see cref="GetAllItems"/>.
        /// </summary>
        /// <remarks>
        /// When no start item is stored, when the start item is the clicked item, or when
        /// either end of the range is not among the enumerated item containers, the call
        /// falls back to <see cref="SelectSingleItem"/>, which also makes the clicked item
        /// the new start item.
        /// </remarks>
        /// <param name="treeView">Owning tree view.</param>
        /// <param name="treeViewItem">Item that was clicked (end of the range).</param>
        /// <param name="shiftControl">
        /// <c>true</c> to leave items outside the range as they are (Ctrl+Shift);
        /// <c>false</c> to deselect them.
        /// </param>
        private static void SelectMultipleItemsContinuously(
            TreeView treeView,
            TreeViewItem treeViewItem,
            bool shiftControl = false
        )
        {
            TreeViewItem startItem = GetStartItem(treeView);

            // Without a start item there is no range; treat the click as a plain selection
            // instead of silently doing nothing.
            if (startItem == null || startItem == treeViewItem)
            {
                SelectSingleItem(treeView, treeViewItem);
                return;
            }

            var allItems = new List<TreeViewItem>();
            GetAllItems(treeView, null, allItems);

            int anchorIndex = allItems.IndexOf(startItem);
            int targetIndex = allItems.IndexOf(treeViewItem);

            // A start item that is no longer enumerated would otherwise leave the range
            // open-ended and select everything up to the end of the tree.
            if (anchorIndex < 0 || targetIndex < 0)
            {
                SelectSingleItem(treeView, treeViewItem);
                return;
            }

            int first = anchorIndex < targetIndex ? anchorIndex : targetIndex;
            int last = anchorIndex < targetIndex ? targetIndex : anchorIndex;

            for (int index = 0; index < allItems.Count; index++)
            {
                if (index >= first && index <= last)
                {
                    SetIsItemSelected(allItems[index], true);
                }
                else if (!shiftControl)
                {
                    SetIsItemSelected(allItems[index], false);
                }
            }
        }

        /// <summary>
        /// Appends every item container below the given owner to
        /// <paramref name="allItems"/>, parent before children, whether or not the parent
        /// is expanded. The owner is <paramref name="treeView"/> when it is non-null,
        /// otherwise <paramref name="treeViewItem"/>.
        /// </summary>
        /// <param name="treeView">Root to start from, or <c>null</c> to start from an item.</param>
        /// <param name="treeViewItem">Parent item used when <paramref name="treeView"/> is <c>null</c>.</param>
        /// <param name="allItems">Collection that receives the containers.</param>
        private static void GetAllItems(
            TreeView treeView,
            TreeViewItem treeViewItem,
            ICollection<TreeViewItem> allItems
        )
        {
            // Both levels are ItemsControls, so a single loop serves either of them.
            ItemsControl owner;
            if (treeView != null)
                owner = treeView;
            else
                owner = treeViewItem;

            for (int index = 0; index < owner.Items.Count; index++)
            {
                // Indices for which no TreeViewItem container is returned are skipped.
                if (owner.ItemContainerGenerator.ContainerFromIndex(index) is TreeViewItem child)
                {
                    allItems.Add(child);
                    GetAllItems(null, child, allItems);
                }
            }
        }

        /// <summary>
        /// Appends the item containers below the given owner to
        /// <paramref name="allItems"/>, parent before children. Every container at the
        /// current level is added, but children are visited only for items whose
        /// <c>IsExpanded</c> is <c>true</c>. The owner is <paramref name="treeView"/> when
        /// it is non-null, otherwise <paramref name="treeViewItem"/>.
        /// </summary>
        /// <param name="treeView">Root to start from, or <c>null</c> to start from an item.</param>
        /// <param name="treeViewItem">Parent item used when <paramref name="treeView"/> is <c>null</c>.</param>
        /// <param name="allItems">Collection that receives the containers.</param>
        private static void GetAllVisibleItems(
            TreeView treeView,
            TreeViewItem treeViewItem,
            ICollection<TreeViewItem> allItems
        )
        {
            // Both levels are ItemsControls, so a single loop serves either of them.
            ItemsControl owner;
            if (treeView != null)
                owner = treeView;
            else
                owner = treeViewItem;

            for (int index = 0; index < owner.Items.Count; index++)
            {
                if (owner.ItemContainerGenerator.ContainerFromIndex(index) is not TreeViewItem child)
                {
                    continue;
                }

                allItems.Add(child);

                // Descend only into expanded items.
                if (child.IsExpanded)
                {
                    GetAllVisibleItems(null, child, allItems);
                }
            }
        }

        /// <summary>
        /// Selects every item container that <see cref="GetAllVisibleItems"/> returns
        /// (children of collapsed items are left out) and makes the first one the
        /// range-selection start item.
        /// </summary>
        /// <param name="treeView">Tree view to operate on; <c>null</c> is ignored.</param>
        private static void SelectAllVisibleItems(TreeView treeView)
        {
            if (treeView is null)
                return;

            var containers = new List<TreeViewItem>();
            GetAllVisibleItems(treeView, null, containers);

            for (int index = 0; index < containers.Count; index++)
            {
                SetIsItemSelected(containers[index], true);
            }

            // Anchor later range selections at the first item.
            if (containers.Count != 0)
            {
                SetStartItem(treeView, containers[0]);
            }
        }
    }
}
